package com.ld.shieldsb.canalclient.handler.impl.db;

import java.io.Closeable;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

import javax.sql.DataSource;

import com.ld.shieldsb.canalclient.util.SyncUtil;

import lombok.extern.slf4j.Slf4j;

/**
 * SQL batch executor: runs statements one by one on a single, lazily opened connection and then
 * commits or rolls them back together.
 *
 * <p>When {@code useTransition} is {@code true} (the default) a newly opened connection is switched
 * to manual-commit mode, so statements run through {@link #execute} are persisted only by
 * {@link #commit()}. When it is {@code false} the connection's auto-commit setting is left exactly as
 * the {@link DataSource} supplies it.
 *
 * <p>Not thread-safe: the underlying {@link Connection} is shared without any synchronization, so an
 * instance should be used by one thread at a time.
 *
 * @author rewerma 2018-11-7 06:45:49 PM
 * @version 1.0.0
 */
@Slf4j
public class BatchExecutor implements Closeable {

    /** Source of connections; consulted only when no connection is currently open. */
    private final DataSource dataSource;
    /** Connection shared by every statement of the current batch; {@code null} when not open. */
    private Connection conn;
    /** Statements executed since the last commit/rollback (reported in trace logs only). */
    private final AtomicInteger idx = new AtomicInteger(0);
    /** Whether a newly opened connection is switched to manual-commit mode. */
    private final boolean useTransition;

    /**
     * Creates an executor that batches statements in a transaction.
     *
     * @param dataSource source of the connection used for all statements
     */
    public BatchExecutor(DataSource dataSource) {
        this(dataSource, true);
    }

    /**
     * Creates an executor.
     *
     * @param dataSource source of the connection used for all statements
     * @param useTransition {@code true} to disable auto-commit on the connection so statements are
     *     only persisted by {@link #commit()}
     */
    public BatchExecutor(DataSource dataSource, boolean useTransition) {
        this.dataSource = dataSource;
        this.useTransition = useTransition;
    }

    /**
     * Returns the connection used by this executor, opening it on first use.
     *
     * @author 吕凯
     * @date 2021-12-07 10:50:55 AM
     * @return the open connection, or {@code null} if it could not be opened or configured (the
     *     failure is logged, not thrown)
     */
    public Connection getConn() {
        try {
            return openIfNeeded();
        } catch (SQLException e) {
            log.error(e.getMessage(), e);
            return null;
        }
    }

    /**
     * Opens and configures the connection if none is open yet, propagating any failure.
     *
     * <p>The connection is cached only after it has been fully configured. If disabling auto-commit
     * fails, the connection is closed again rather than kept, so callers never get a connection that
     * silently auto-commits every statement of a batch that was meant to be transactional.
     *
     * @return the open connection
     * @throws SQLException if the connection cannot be obtained or configured
     */
    private Connection openIfNeeded() throws SQLException {
        if (conn == null) {
            Connection fresh = dataSource.getConnection();
            if (useTransition) {
                try {
                    fresh.setAutoCommit(false); // defer persistence until commit()
                } catch (SQLException e) {
                    try {
                        fresh.close();
                    } catch (SQLException closeError) {
                        e.addSuppressed(closeError);
                    }
                    throw e;
                }
            }
            conn = fresh;
        }
        return conn;
    }

    /**
     * Appends one bound-parameter descriptor to {@code values}, in the shape {@link #execute} reads.
     *
     * @param values parameter list being built; parameters are bound in list order
     * @param type type code passed through to {@code SyncUtil.setPStmt} (presumably a
     *     {@code java.sql.Types} constant — confirm against SyncUtil)
     * @param value value to bind; may be {@code null}
     */
    public static void setValue(List<Map<String, ?>> values, int type, Object value) {
        Map<String, Object> valueItem = new HashMap<>(); // HashMap because value may be null
        valueItem.put("type", type);
        valueItem.put("value", value);
        values.add(valueItem);
    }

    /**
     * Prepares {@code sql}, binds {@code values} in order and executes it on the shared connection.
     *
     * @param sql SQL with {@code ?} placeholders
     * @param values one map per placeholder with keys {@code "type"} (Integer) and {@code "value"},
     *     as produced by {@link #setValue}
     * @throws SQLException if the connection cannot be opened, or preparing/binding/executing fails
     */
    public void execute(String sql, List<Map<String, ?>> values) throws SQLException {
        // openIfNeeded() surfaces connection failures as SQLException instead of a later NPE.
        try (PreparedStatement pstmt = openIfNeeded().prepareStatement(sql)) {
            int len = values.size();
            for (int i = 0; i < len; i++) {
                Map<String, ?> item = values.get(i);
                int type = (Integer) item.get("type");
                Object value = item.get("value");
                SyncUtil.setPStmt(type, pstmt, value, i + 1); // JDBC parameter indexes are 1-based
            }
            if (log.isDebugEnabled()) {
                log.debug("执行SQL：{}，参数：{}", sql, values);
            }
            pstmt.execute();
            idx.incrementAndGet();
        }
    }

    /**
     * Commits the statements executed since the last commit/rollback and resets the counter.
     *
     * <p>Nothing is done when no connection is open (nothing was executed), or when the connection is
     * in auto-commit mode: there every statement is already committed, and JDBC specifies that
     * {@code Connection.commit()} throws in that mode.
     *
     * @throws SQLException if the commit fails
     */
    public void commit() throws SQLException {
        if (conn != null && !conn.getAutoCommit()) {
            conn.commit();
        }
        log.trace("批量提交 {} 条数据", idx.get());
        idx.set(0);
    }

    /**
     * Rolls back the statements executed since the last commit/rollback and resets the counter.
     *
     * <p>When no connection is open there is nothing to roll back, so no connection is opened just
     * for this call.
     *
     * @throws SQLException if the rollback fails (including, per JDBC, when the connection is in
     *     auto-commit mode)
     */
    public void rollback() throws SQLException {
        if (conn != null) {
            conn.rollback();
        }
        log.trace("批量回滚 {} 条数据", idx.get());
        idx.set(0);
    }

    /**
     * Closes the connection if one is open and forgets it, so a later call opens a fresh one.
     *
     * <p>Per JDBC, closing with an active transaction has implementation-defined results, so callers
     * using transactions should {@link #commit()} or {@link #rollback()} first. Close failures are
     * logged rather than thrown.
     */
    @Override
    public void close() {
        if (conn != null) {
            try {
                conn.close();
            } catch (SQLException e) {
                log.error(e.getMessage(), e);
            } finally {
                conn = null;
            }
        }
    }
}
